import json, os
import random
import traceback
import numpy as np
from pycocotools import mask as mask_util
from tqdm import tqdm
from huggingface_hub import hf_hub_download
import ast
from PIL import Image, ImageDraw, ImageFont
from PIL import ImageColor
import xml.etree.ElementTree as ET

# Every colour name PIL knows about; appended to the fixed palette in
# plot_bounding_boxes so long box lists still get distinct colours.
additional_colors = list(ImageColor.colormap)


def find_longest_segment(lst):
    """
    Return the longest run of consecutive values in ``lst``.

    A run is a stretch of adjacent elements where each one equals its
    predecessor plus one. When several runs share the maximum length, the
    earliest one is returned. An empty input yields an empty list.

    Args:
        lst: Indexable sequence of values supporting ``+ 1`` and ``==``.

    Returns:
        A new list containing the elements of the longest run, in order.
    """
    if not lst:
        return []

    n_items = len(lst)
    best_start, best_len = 0, 0
    run_start = 0
    for pos in range(1, n_items + 1):
        at_end = pos == n_items
        if not at_end and lst[pos] == lst[pos - 1] + 1:
            continue  # still extending the current run
        # The run lst[run_start:pos] just ended; keep it only if strictly longer
        # so that ties resolve to the earlier run.
        run_len = pos - run_start
        if run_len > best_len:
            best_start, best_len = run_start, run_len
        run_start = pos
    return [lst[k] for k in range(best_start, best_start + best_len)]


import numpy as np


def get_bbox_from_mask(mask):
    """
    Compute the tight axis-aligned bounding box of the foreground in a 2-D mask.

    Args:
        mask: 2-D array exposing ``astype`` (e.g. a NumPy array); any value that
            is truthy after conversion to ``bool`` counts as foreground.

    Returns:
        ``(x_min, y_min, x_max, y_max)`` as Python ints, with inclusive maxima,
        or ``None`` when the mask contains no foreground.
    """
    row_idx, col_idx = np.nonzero(mask.astype(bool))
    if row_idx.size == 0:
        return None  # empty mask: nothing to bound

    # Columns are x, rows are y.
    return (
        int(col_idx.min()),
        int(row_idx.min()),
        int(col_idx.max()),
        int(row_idx.max()),
    )


# @title Parsing JSON output
def parse_json(json_output):
    """
    Strip a markdown ```json fence from model output.

    If a line consisting of ```json (surrounding whitespace ignored) is present,
    everything up to and including that line is dropped, as is everything from
    the next ``` onwards. Text without such a fence line is returned unchanged.

    Args:
        json_output: Raw text, e.g. a language-model response.

    Returns:
        The fenced payload, or ``json_output`` itself if no fence was found.
    """
    lines = json_output.splitlines()
    for i, line in enumerate(lines):
        # strip() so an indented fence or trailing spaces after the tag still match.
        if line.strip() == "```json":
            payload = "\n".join(lines[i + 1:])
            # Keep only what precedes the closing fence.
            return payload.split("```")[0]
    return json_output


def plot_bounding_boxes(im, bounding_boxes, input_width, input_height,
                        save_path="vis/output.jpg"):
    """
    Draw labelled bounding boxes onto a PIL image and save the result.

    Each box gets its own outline colour, cycling through a fixed palette
    followed by every named PIL colour. Drawing happens directly on ``im``, so
    the caller's image object is modified in place.

    Args:
        im: PIL ``Image`` to draw on; mutated in place.
        bounding_boxes: Text holding a list of dicts, each with a ``"bbox_2d"``
            entry in ``[x1, y1, x2, y2]`` order and an optional ``"label"``.
            The text may be wrapped in a ```json markdown fence. If it does not
            parse as a Python literal, it is cut after the last ``"}`` and
            closed with ``]`` to recover a list that was cut off mid-entry.
        input_width: Width of the coordinate space the boxes are expressed in.
        input_height: Height of the coordinate space the boxes are expressed in.
        save_path: File the annotated image is written to. Its parent
            directory is created if it does not exist.

    Returns:
        The annotated image (the same object as ``im``).

    Raises:
        ValueError, SyntaxError, TypeError: if the box text cannot be parsed,
            even after truncation.
        KeyError: if an entry has no ``"bbox_2d"`` key.
    """
    img = im
    width, height = img.size
    draw = ImageDraw.Draw(img)

    # Fixed palette first, then every other named PIL colour for long box lists.
    colors = [
        'red',
        'green',
        'blue',
        'yellow',
        'orange',
        'pink',
        'purple',
        'brown',
        'gray',
        'beige',
        'turquoise',
        'cyan',
        'magenta',
        'lime',
        'navy',
        'maroon',
        'teal',
        'olive',
        'coral',
        'lavender',
        'violet',
        'gold',
        'silver',
    ] + additional_colors

    # Strip any markdown fence before parsing.
    text = parse_json(bounding_boxes)
    font = _load_label_font(size=14)
    boxes = _load_bounding_boxes(text)

    for i, bounding_box in enumerate(boxes):
        color = colors[i % len(colors)]
        bbox = bounding_box["bbox_2d"]

        # Rescale from the input_width x input_height coordinate space to the
        # actual pixel size of the image.
        abs_x1 = int(bbox[0] / input_width * width)
        abs_y1 = int(bbox[1] / input_height * height)
        abs_x2 = int(bbox[2] / input_width * width)
        abs_y2 = int(bbox[3] / input_height * height)

        # Order the corners so (abs_x1, abs_y1) is top-left even if they arrive swapped.
        if abs_x1 > abs_x2:
            abs_x1, abs_x2 = abs_x2, abs_x1
        if abs_y1 > abs_y2:
            abs_y1, abs_y2 = abs_y2, abs_y1

        draw.rectangle(((abs_x1, abs_y1), (abs_x2, abs_y2)), outline=color, width=4)

        if "label" in bounding_box:
            # Offset the label into the box, past the 4px outline.
            draw.text((abs_x1 + 8, abs_y1 + 6), bounding_box["label"], fill=color, font=font)

    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    img.save(save_path)
    return img


def _load_bounding_boxes(text):
    """Parse box-list text, retrying once on a copy truncated after the last ``"}``."""
    try:
        return ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError):
        end_idx = text.rfind('"}')
        if end_idx == -1:
            # Nothing recoverable: surface the original parse error instead of
            # evaluating a meaningless 1-character prefix.
            raise
        return ast.literal_eval(text[:end_idx + len('"}')] + "]")


def _load_label_font(size=14):
    """Return the Noto Sans CJK label font, or PIL's default font if it cannot be opened."""
    try:
        return ImageFont.truetype("NotoSansCJK-Regular.ttc", size=size)
    except OSError:
        return ImageFont.load_default()


if __name__ == '__main__':
    # Score temporal-grounding predictions: for every sample, take the middle
    # index of the predicted [start, end] range and count a hit when it falls
    # inside the ground-truth range. Prints: hits, total, accuracy.

    # split = 'train'

    # refer_data = json.load(open(f'scanrefer/ScanRefer_filtered_{split}.json', 'r'))

    # train_scan_ids = np.loadtxt(f'scanrefer/ScanRefer_filtered_{split}.txt', dtype=str)

    # Download the file (if not already cached) and get its local path.
    # NOTE(review): repo_id is empty here (presumably redacted); fill it in
    # before running. mask_dict is only used by the commented-out mask
    # pipeline further down, not by the accuracy computation.
    file_path = hf_hub_download(
        repo_id="",
        filename="mask_dict.json",
        repo_type="dataset",
        # cache_dir="./"     # optional: specify your custom cache directory
    )
    # print(file_path)
    with open(file_path) as f:
        mask_dict = json.load(f)

    meta = {}
    correct = 0

    # train_scan_ids = train_scan_ids[:10]

    with open('outputs/scanrefer_1/checkpoint-3000/preds.json', 'r') as f:
        preds = json.load(f)

    for pred in tqdm(preds):
        # print(pred)

        scan_id = pred['video_id']

        try:
            # Values are doubled and truncated to integer indices
            # (presumably 2 sampled frames per second -- TODO confirm).
            pred_s = int(pred['pred'][0] * 2)
            pred_e = int(pred['pred'][1] * 2)
            gt_s = int(pred['gt'][0] * 2)
            gt_e = int(pred['gt'][1] * 2)

            # A range is indexable without building a list; a reversed
            # (empty) prediction still raises IndexError and counts as a miss.
            pred_inds = range(pred_s, pred_e + 1)
            # print(pred_inds)

            selected_idx = pred_inds[len(pred_inds) // 2]
            # selected_idx = random.choice(pred_inds)

            if gt_s <= selected_idx <= gt_e:
                correct += 1
        except Exception:
            # Best effort: report the malformed prediction and score it as a miss.
            print(pred)
            traceback.print_exc()

        # root_dir = f'/data2/datasets/Sa2VA-Training/video_datas/scannet/JPEGImages/{scan_id}'
        # frames = os.listdir(root_dir)
        # frames = [frame.split('.')[0] for frame in frames]
        # frames = sorted(frames, key=lambda x: int(x))

        # selected_mask = mask_util.decode(frame_masks[selected_idx])
        # selected_mask = selected_mask.astype(bool)

        # bbox = get_bbox_from_mask(selected_mask)
        # print(bbox)
        # bbox = list(bbox)

        # meta[scan_id]["bbox_2d"].append(bbox)
        # meta[scan_id]["sentences"].append(data['description'])
        # meta[scan_id]["image_path"].append(img_path)

    # Guard against an empty predictions file instead of dividing by zero.
    accuracy = correct / len(preds) if preds else 0.0
    print(correct, len(preds), accuracy)

    # break

    # for scan_id in tqdm(train_scan_ids):
    #     root_dir = f'/data2/datasets/Sa2VA-Training/video_datas/scannet/JPEGImages/{scan_id}'
    #     frames = os.listdir(root_dir)
    #     frames = [frame.split('.')[0] for frame in frames]
    #     frames = sorted(frames, key=lambda x: int(x))

    #     # print(len(frames), frames[:10])
    #     fps = 3
    #     meta[scan_id] = {"duration": len(frames) / fps, "image_path": [], "sentences": [], "bbox_2d": []}

    #     idx = 0
    #     for data in refer_data:
    #         if data['scene_id'] == scan_id:
    #             obj_id = int(data['object_id']) + 1
    #             frame_masks = mask_dict[f'{scan_id}_{obj_id}']
    #             roi_frames = []

    #             for frame_idx in range(len(frames)):
    #                 gt_mask = mask_util.decode(frame_masks[frame_idx])
    #                 gt_mask = gt_mask.astype(bool)
    #                 if gt_mask.sum() > 500:
    #                     roi_frames.append(frame_idx)

    #             if len(roi_frames) == 0:
    #                 continue

    #             selected_idx = random.choice(roi_frames)
    #             selected_mask = mask_util.decode(frame_masks[selected_idx])
    #             selected_mask = selected_mask.astype(bool)

    #             bbox = get_bbox_from_mask(selected_mask)
    #             # print(bbox)
    #             bbox = list(bbox)
    #             bbox_dict = {"bbox_2d": [bbox[0], bbox[1], bbox[2], bbox[3]], "label": data['description']}
    #             bbox_str = json.dumps([bbox_dict], indent=4)
    #             # print(bbox_str)
    #             img_path = os.path.join(root_dir, frames[selected_idx] + '.jpg')
    #             # img = Image.open(img_path)
    #             # plot_bounding_boxes(img, bbox_str, img.size[0], img.size[1])
    #             # img.show()

    #             meta[scan_id]["bbox_2d"].append(bbox)
    #             meta[scan_id]["sentences"].append(data['description'])
    #             meta[scan_id]["image_path"].append(img_path)

    # break

    # break

